{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/codenames-rl/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-13 07:53:32 [importing.py:53] Triton module has been replaced with a placeholder.\n",
      "INFO 05-13 07:53:32 [__init__.py:239] Automatically detected platform cuda.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-13 07:53:36,221\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "import art\n",
    "from art.local import LocalBackend\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def rollout(\n",
    "    model: art.TrainableModel,\n",
    "    prompt: str,\n",
    "    *,\n",
    "    max_tokens: int = 100,\n",
    "    timeout: float = 100,\n",
    "    verbose: bool = False,\n",
    ") -> art.Trajectory:\n",
    "    \"\"\"Sample one completion for `prompt` and score it as a trajectory.\n",
    "\n",
    "    The prompt is sent as a single user message through the client returned\n",
    "    by `model.openai_client()`. The reward is an exact, case-sensitive match\n",
    "    on the reply text: 'maybe' -> 1.0, 'no' -> 0.75, 'yes' -> 0.5, and any\n",
    "    other reply (including 'Yes' or 'Yes.') -> 0.0.\n",
    "\n",
    "    Args:\n",
    "        model: Model to sample from; its `openai_client()` and `name` are used.\n",
    "        prompt: Text of the user message.\n",
    "        max_tokens: Completion length cap forwarded to the request.\n",
    "        timeout: Timeout value forwarded to the request.\n",
    "        verbose: Print the raw reply. Off by default because the training\n",
    "            cell runs this coroutine many times per step.\n",
    "\n",
    "    Returns:\n",
    "        A trajectory holding the user message, the sampled choice and the reward.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the reply content is not a string.\n",
    "    \"\"\"\n",
    "    # Exact-match reward table; replies not listed here score 0.0.\n",
    "    exact_match_rewards = {\"yes\": 0.5, \"no\": 0.75, \"maybe\": 1.0}\n",
    "    messages: art.Messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "    client = model.openai_client()\n",
    "    chat_completion = await client.chat.completions.create(\n",
    "        messages=messages,\n",
    "        model=model.name,\n",
    "        max_tokens=max_tokens,\n",
    "        timeout=timeout,\n",
    "        # Passes enable_thinking=False through to the chat template.\n",
    "        extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",
    "    )\n",
    "    choice = chat_completion.choices[0]\n",
    "    content = choice.message.content\n",
    "    assert isinstance(content, str), f\"expected text content, got {content!r}\"\n",
    "    if verbose:\n",
    "        print(content)\n",
    "    reward = exact_match_rewards.get(content, 0.0)\n",
    "    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"respond with 'yes', 'no', 'maybe'\", \"respond with 'maybe', 'yes', 'no'\", \"respond with 'no', 'yes', 'maybe'\", \"respond with 'yes', 'maybe', 'no'\", 'respond with yes or no', 'respond with maybe or no', 'respond with no or maybe', 'respond with no or yes', 'respond with yes or no', 'respond with yes, no, maybe', 'respond with maybe, yes, no', 'respond with no, yes, maybe', 'respond with yes, maybe, no', 'respond with yes or no', 'respond with maybe or no', 'respond with no or maybe', 'respond with no or yes', 'respond with yes or no', \"just respond with 'yes', 'no', 'maybe'\", \"just respond with 'maybe', 'yes', 'no'\", \"just respond with 'no', 'yes', 'maybe'\", \"just respond with 'yes', 'maybe', 'no'\", 'just respond with yes or no', 'just respond with maybe or no', 'just respond with no or maybe', 'just respond with no or yes', 'just respond with yes or no', 'just respond with yes, no, maybe', 'just respond with maybe, yes, no', 'just respond with no, yes, maybe', 'just respond with yes, maybe, no', 'just respond with yes or no', 'just respond with maybe or no', 'just respond with no or maybe', 'just respond with no or yes', 'just respond with yes or no']\n"
     ]
    }
   ],
   "source": [
    "# Read as UTF-8 explicitly so the result does not depend on the platform's\n",
    "# default text encoding.\n",
    "with open(\"prompts.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    prompts = json.load(f)\n",
    "# Validate here so a malformed file fails in this cell rather than inside a\n",
    "# later rollout or training cell.\n",
    "assert isinstance(prompts, list) and prompts, \"prompts.json must hold a non-empty list\"\n",
    "assert all(isinstance(p, str) for p in prompts), \"every prompt must be a string\"\n",
    "print(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single backend instance; the models in the cells below are registered against it.\n",
    "backend = LocalBackend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/ART/src/art/__init__.py:11: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n",
      "\n",
      "Please restructure your imports with 'import unsloth' at the top of your file.\n",
      "  import unsloth  # type: ignore\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
      "INFO 05-13 07:55:12 [importing.py:53] Triton module has been replaced with a placeholder.\n",
      "INFO 05-13 07:55:12 [__init__.py:239] Automatically detected platform cuda.\n",
      "Standard import failed for UnslothBCOTrainer: No module named 'UnslothBCOTrainer'. Using tempfile instead!\n",
      "==((====))==  Unsloth 2025.5.1: Fast Qwen2 patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.\n",
      "   \\\\   /|    NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.205 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
      " \"-____-\"     Free license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
      "Unsloth: vLLM loading unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.4%\n",
      "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.2 GB.\n",
      "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n",
      "Unsloth: vLLM's KV Cache can use up to 61.63 GB. Also swap space = 6 GB.\n",
      "INFO 05-13 07:55:46 [config.py:717] This model supports multiple tasks: {'classify', 'generate', 'score', 'embed', 'reward'}. Defaulting to 'generate'.\n",
      "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.0.self_attn', 'model.layers.0.mlp', 'model.layers.2.mlp', 'model.layers.3.mlp', 'model.layers.21.mlp', 'model.layers.0.self_attn.q_proj'], 'llm_int8_threshold': 6.0}\n",
      "INFO 05-13 07:55:46 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.post1) with config: model='unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit, num_scheduler_steps=16, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":0,\"splitting_ops\":[],\"compile_sizes\":[],\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":368}, use_cached_outputs=False, \n",
      "INFO 05-13 07:55:47 [cuda.py:292] Using Flash Attention backend.\n",
      "INFO 05-13 07:55:48 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n",
      "INFO 05-13 07:55:48 [model_runner.py:1108] Starting to load model unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit...\n",
      "INFO 05-13 07:55:48 [loader.py:1187] Loading weights with BitsAndBytes quantization. May take a while ...\n",
      "INFO 05-13 07:55:48 [weight_utils.py:265] Using model weights format ['*.safetensors']\n",
      "INFO 05-13 07:55:49 [weight_utils.py:315] No model.safetensors.index.json found in remote.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.34it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.33it/s]\n",
      "\n",
      "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.21it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.21it/s]\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-13 07:55:49 [punica_selector.py:18] Using PunicaWrapperGPU.\n",
      "INFO 05-13 07:55:50 [model_runner.py:1140] Model loading took 0.5153 GiB and 1.320794 seconds\n",
      "INFO 05-13 07:55:52 [worker.py:287] Memory profiling takes 2.51 seconds\n",
      "INFO 05-13 07:55:52 [worker.py:287] the current vLLM instance can use total_gpu_memory (79.20GiB) x gpu_memory_utilization (0.78) = 62.10GiB\n",
      "INFO 05-13 07:55:52 [worker.py:287] model weights take 0.52GiB; non_torch_memory takes 0.16GiB; PyTorch activation peak memory takes 2.07GiB; the rest of the memory reserved for KV Cache is 59.36GiB.\n",
      "INFO 05-13 07:55:53 [executor_base.py:112] # cuda blocks: 324184, # CPU blocks: 32768\n",
      "INFO 05-13 07:55:53 [executor_base.py:117] Maximum concurrency for 32768 tokens per request: 158.29x\n",
      "INFO 05-13 07:55:56 [model_runner.py:1450] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Capturing CUDA graph shapes: 100%|██████████| 49/49 [00:30<00:00,  1.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-13 07:56:26 [model_runner.py:1592] Graph capturing finished in 30 secs, took 0.84 GiB\n",
      "INFO 05-13 07:56:26 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 36.86 seconds\n",
      "Unsloth: Just some info: will skip parsing ['q_norm', 'post_feedforward_layernorm', 'pre_feedforward_layernorm', 'k_norm']\n",
      "Unsloth: Just some info: will skip parsing ['q_norm', 'post_feedforward_layernorm', 'pre_feedforward_layernorm', 'k_norm']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsloth 2025.5.1 patched 24 layers with 24 QKV layers, 24 O layers and 24 MLP layers.\n"
     ]
    }
   ],
   "source": [
    "# Qwen2.5-0.5B-Instruct variant, registered with the shared backend.\n",
    "qwen2 = art.TrainableModel(\n",
    "    name=\"004\",\n",
    "    project=\"yes-no-maybe-s\",\n",
    "    base_model=\"Qwen/Qwen2.5-0.5B-Instruct\",\n",
    ")\n",
    "await qwen2.register(backend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Yes\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Trajectory(messages_and_choices=[{'content': 'respond with yes or no', 'role': 'user'}, Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token='token_id:9454', bytes=[89, 101, 115], logprob=0.0, top_logprobs=[]), ChatCompletionTokenLogprob(token='token_id:151645', bytes=[], logprob=-0.6409199237823486, top_logprobs=[])], refusal=None), message=ChatCompletionMessage(content='Yes', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], tools=None, reward=0.0, metrics={}, metadata={}, logs=[], start_time=datetime.datetime(2025, 5, 13, 7, 57, 45, 484026))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: a single rollout on prompts[4]; the returned trajectory is the cell output.\n",
    "await rollout(qwen2, prompts[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/ART/src/art/__init__.py:11: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n",
      "\n",
      "Please restructure your imports with 'import unsloth' at the top of your file.\n",
      "  import unsloth  # type: ignore\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
      "INFO 05-13 08:00:16 [importing.py:53] Triton module has been replaced with a placeholder.\n",
      "INFO 05-13 08:00:17 [__init__.py:239] Automatically detected platform cuda.\n",
      "Standard import failed for UnslothCPOTrainer: No module named 'UnslothCPOTrainer'. Using tempfile instead!\n",
      "Unsloth: Switching from Unsloth dynamic quant to normal quant since\n",
      "we do not yet support fast inference for unsloth/qwen3-0.6b-unsloth-bnb-4bit\n",
      "==((====))==  Unsloth 2025.5.1: Fast Qwen3 patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.\n",
      "   \\\\   /|    NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.205 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
      " \"-____-\"     Free license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
      "Unsloth: vLLM loading unsloth/qwen3-0.6b-bnb-4bit with actual GPU utilization = 78.4%\n",
      "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.2 GB.\n",
      "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n",
      "Unsloth: vLLM's KV Cache can use up to 61.59 GB. Also swap space = 6 GB.\n",
      "INFO 05-13 08:00:50 [config.py:717] This model supports multiple tasks: {'embed', 'score', 'reward', 'classify', 'generate'}. Defaulting to 'generate'.\n",
      "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection'], 'llm_int8_threshold': 6.0}\n",
      "INFO 05-13 08:00:51 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.post1) with config: model='unsloth/qwen3-0.6b-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen3-0.6b-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=unsloth/qwen3-0.6b-bnb-4bit, num_scheduler_steps=16, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":0,\"splitting_ops\":[],\"compile_sizes\":[],\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":368}, use_cached_outputs=False, \n",
      "INFO 05-13 08:00:51 [cuda.py:292] Using Flash Attention backend.\n",
      "INFO 05-13 08:00:52 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n",
      "INFO 05-13 08:00:52 [model_runner.py:1108] Starting to load model unsloth/qwen3-0.6b-bnb-4bit...\n",
      "INFO 05-13 08:00:52 [loader.py:1187] Loading weights with BitsAndBytes quantization. May take a while ...\n",
      "INFO 05-13 08:00:52 [weight_utils.py:265] Using model weights format ['*.safetensors']\n",
      "INFO 05-13 08:00:52 [weight_utils.py:315] No model.safetensors.index.json found in remote.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.84it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.83it/s]\n",
      "\n",
      "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.32it/s]\n",
      "Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.31it/s]\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-13 08:00:53 [punica_selector.py:18] Using PunicaWrapperGPU.\n",
      "INFO 05-13 08:00:53 [model_runner.py:1140] Model loading took 0.5271 GiB and 1.011840 seconds\n",
      "INFO 05-13 08:00:55 [worker.py:287] Memory profiling takes 1.93 seconds\n",
      "INFO 05-13 08:00:55 [worker.py:287] the current vLLM instance can use total_gpu_memory (79.20GiB) x gpu_memory_utilization (0.78) = 62.10GiB\n",
      "INFO 05-13 08:00:55 [worker.py:287] model weights take 0.53GiB; non_torch_memory takes 0.16GiB; PyTorch activation peak memory takes 2.07GiB; the rest of the memory reserved for KV Cache is 59.34GiB.\n",
      "INFO 05-13 08:00:56 [executor_base.py:112] # cuda blocks: 34722, # CPU blocks: 3510\n",
      "INFO 05-13 08:00:56 [executor_base.py:117] Maximum concurrency for 32768 tokens per request: 16.95x\n",
      "INFO 05-13 08:00:59 [model_runner.py:1450] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Capturing CUDA graph shapes: 100%|██████████| 49/49 [00:32<00:00,  1.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-13 08:01:31 [model_runner.py:1592] Graph capturing finished in 32 secs, took 1.16 GiB\n",
      "INFO 05-13 08:01:31 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 37.90 seconds\n",
      "Unsloth: Just some info: will skip parsing ['post_feedforward_layernorm', 'pre_feedforward_layernorm']\n",
      "Unsloth: Just some info: will skip parsing ['post_feedforward_layernorm', 'pre_feedforward_layernorm']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsloth 2025.5.1 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.\n"
     ]
    }
   ],
   "source": [
    "# Qwen3-0.6B variant, registered with the same backend under its own run name.\n",
    "qwen3 = art.TrainableModel(\n",
    "    name=\"005\",\n",
    "    project=\"yes-no-maybe-s\",\n",
    "    base_model=\"Qwen/Qwen3-0.6B\",\n",
    ")\n",
    "await qwen3.register(backend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Yes.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Trajectory(messages_and_choices=[{'content': 'respond with yes or no', 'role': 'user'}, Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token='token_id:9454', bytes=[89, 101, 115], logprob=-0.0793837234377861, top_logprobs=[]), ChatCompletionTokenLogprob(token='token_id:13', bytes=[46], logprob=0.0, top_logprobs=[]), ChatCompletionTokenLogprob(token='token_id:151645', bytes=[], logprob=0.0, top_logprobs=[])], refusal=None), message=ChatCompletionMessage(content='Yes.', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], tools=None, reward=0.0, metrics={}, metadata={}, logs=[], start_time=datetime.datetime(2025, 5, 13, 8, 3, 38, 210520))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "# Smoke test: a single rollout on prompts[4]; the returned trajectory is the cell output.\n",
    "await rollout(qwen3, prompts[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tunables for this run, named here instead of being buried in the loop.\n",
    "TOTAL_STEPS = 1_000\n",
    "ROLLOUTS_PER_PROMPT = 32\n",
    "LEARNING_RATE = 1e-4\n",
    "\n",
    "# Begin at the value reported by qwen3.get_step() (presumably the number of\n",
    "# steps already trained -- confirm). The loop variable is named `step` so the\n",
    "# top-level loop does not rebind `_`, which IPython uses for the last output.\n",
    "for step in range(await qwen3.get_step(), TOTAL_STEPS):\n",
    "    # One trajectory group per prompt, each built from ROLLOUTS_PER_PROMPT rollouts.\n",
    "    train_groups = await art.gather_trajectory_groups(\n",
    "        (\n",
    "            art.TrajectoryGroup(\n",
    "                rollout(qwen3, prompt) for _ in range(ROLLOUTS_PER_PROMPT)\n",
    "            )\n",
    "            for prompt in prompts\n",
    "        ),\n",
    "        pbar_desc=\"gather\",\n",
    "    )\n",
    "    await qwen3.train(\n",
    "        train_groups,\n",
    "        config=art.TrainConfig(learning_rate=LEARNING_RATE),\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
